{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "922df148-5712-4df5-ba57-e2b712511e36",
   "metadata": {},
   "source": [
    "# SMERF Demo\n",
    "\n",
    "This notebook demonstrates how to train and render a (prebaked) SMERF model. \n",
    "It should produce identical results to `scripts/demo.sh`.\n",
    "\n",
    "The notebook requires a CUDA-capable GPU with at least 12 GB of VRAM, such as an RTX 3080 Ti. \n",
    "See `README.md` for instructions.\n",
    "\n",
    "Licensed under the Apache License, Version 2.0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "715c68d0-b41f-42fc-83bd-06ad83c10ffa",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "This section imports necessary libraries and initializes configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc639bb8-0143-43b2-b439-dd723c5a1500",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# @title Imports\n",
    "import datetime\n",
    "import functools\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "\n",
    "from absl import app\n",
    "from absl import flags\n",
    "import camp_zipnerf.internal\n",
    "from etils import ecolab\n",
    "from etils import epath\n",
    "import flax\n",
    "from flax.training import checkpoints\n",
    "import gin\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import mediapy as media\n",
    "import numpy as np\n",
    "import pycolmap\n",
    "import smerf.internal\n",
    "\n",
    "# NOTE(review): presumably switches Gin into its notebook-friendly mode so that cells can be\n",
    "# re-executed without Gin rejecting redefinitions -- confirm against the Gin documentation.\n",
    "gin.enter_interactive_mode()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03a9e0f7-50b7-44f0-af77-2848971c1b66",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# @title Setup config\n",
    "\n",
    "# Create a timestamp (UTC, minute resolution) for this execution. It is embedded in the\n",
    "# checkpoint directory below. `datetime.utcnow()` is deprecated since Python 3.12; the\n",
    "# timezone-aware call formats to exactly the same string.\n",
    "TIMESTAMP = datetime.datetime.now(datetime.timezone.utc).strftime('%Y%m%d_%H%M')\n",
    "\n",
    "# Setup Gin-related flags\n",
    "try:\n",
    "  camp_zipnerf.internal.configs.define_common_flags()\n",
    "  # Jupyter will set the working directory to notebooks/. This sets it to the project root.\n",
    "  os.chdir('../')\n",
    "except Exception as err:\n",
    "  # The chdir sits after define_common_flags() inside the try, so any exception raised by the\n",
    "  # flag definition also skips the chdir. NOTE(review): presumably this is what keeps a re-run\n",
    "  # of this cell from climbing one directory higher each time -- confirm before narrowing it.\n",
    "  print(err)\n",
    "\n",
    "# Gin files are given relative to the project root (see the chdir above).\n",
    "flags.FLAGS.gin_configs = [\n",
    "    'configs/models/smerf.gin',\n",
    "    'configs/mipnerf360/bicycle.gin',\n",
    "    'configs/mipnerf360/extras.gin',\n",
    "    'configs/mipnerf360/rtx3080ti.gin',\n",
    "]\n",
    "# Write this run's SMERF checkpoints to a directory named after TIMESTAMP.\n",
    "flags.FLAGS.gin_bindings = f\"\"\"\n",
    "smerf.internal.configs.Config.checkpoint_dir = 'checkpoints/{TIMESTAMP}-notebook'\n",
    "\"\"\"\n",
    "\n",
    "# Parse gin configs from command line flags. There is no direct way to pass Gin configs to load_config().\n",
    "# Only the program name (sys.argv[:1]) is forwarded, so the kernel's own arguments are ignored.\n",
    "app._run_init(sys.argv[:1], flags.FLAGS)\n",
    "\n",
    "# Parse Gin configs\n",
    "gin.clear_config()\n",
    "smerf_config, config = smerf.internal.distill.load_config(False)\n",
    "\n",
    "# Echo the settings that the rest of the notebook depends on.\n",
    "print(f'{TIMESTAMP=}')\n",
    "print(f'{smerf_config.checkpoint_dir=}')\n",
    "print(f'{smerf_config.baking_checkpoint_dir=}')\n",
    "print(f'{smerf_config.distill_teacher_ckpt_dir=}')\n",
    "print(f'{config.data_dir=}')\n",
    "print(f'{config.batch_size=}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5265e41d-931f-4bb2-821c-20d0d778e8d5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# @title Print Gin config\n",
    "# Prints whatever `gin.config_str()` returns for the current Gin state. print() is used\n",
    "# (rather than a bare expression) so line breaks in the string are rendered, not escaped.\n",
    "print(gin.config_str())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dc5625d-859d-4c7e-baa9-34a6f85aae78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Print GPU RAM usage\n",
    "def memory_stats(device):\n",
    "  stats = device.memory_stats()\n",
    "  gb_in_use = stats['bytes_in_use'] / 2 ** 30\n",
    "  gb_available = stats['bytes_limit'] / 2 ** 30\n",
    "  percent_bytes_in_use = 100 * gb_in_use / gb_available\n",
    "  return f\"GPU_{device.id}: {gb_in_use:0.2f} GiB of {gb_available:0.2f} GiB ({percent_bytes_in_use:0.1f}%)\"\n",
    "\n",
    "jax.tree_map(memory_stats, jax.devices())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3ed0afb-00b3-4051-8aed-ff99dc292101",
   "metadata": {},
   "source": [
    "## Dataset\n",
    "\n",
    "This section loads the training dataset into memory.\n",
    "Make sure that `config.data_dir` points to a directory with a single scene's worth of data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f56cc293-b4d9-47ea-8589-351cc4abe1c8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# @title Load Train\n",
    "# Load the 'train' split from `config.data_dir`. `raybatcher` wraps it and is the object the\n",
    "# training loop later draws batches from.\n",
    "dataset = camp_zipnerf.internal.datasets.load_dataset('train', config.data_dir, config)\n",
    "raybatcher = camp_zipnerf.internal.datasets.RayBatcher(dataset)\n",
    "# The trailing `;` keeps Jupyter from echoing the dataset object.\n",
    "dataset;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fed07fc-bd2d-47bb-8b1c-98f32ebc4d49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Construct camera parameters\n",
    "# Convert NumPy leaves to JAX arrays; every other leaf is passed through untouched.\n",
    "np_to_jax = lambda x: jnp.array(x) if isinstance(x, np.ndarray) else x\n",
    "cameras = dataset.get_train_cameras(config)\n",
    "cameras = jax.tree_util.tree_map(np_to_jax, cameras)\n",
    "# Replicate across local devices (hence the `p` prefix).\n",
    "pcameras = flax.jax_utils.replicate(cameras)\n",
    "\n",
    "# Show the shape of every leaf instead of the arrays themselves. A bare `;s` suffix only works\n",
    "# while etils' auto-display hook is active; in a plain kernel `s` is an undefined name and the\n",
    "# cell fails with NameError.\n",
    "jax.tree_util.tree_map(jnp.shape, pcameras)  # intrinsics, extrinsics, ..."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a8f22ed-b1e8-4aaa-a77f-79b1f6df63b2",
   "metadata": {},
   "source": [
    "## Teacher\n",
    "\n",
    "This section initializes the teacher model. \n",
    "Make sure that `config.checkpoint_dir` points to a directory with a pretrained `camp_zipnerf` model checkpoint.\n",
    "\n",
    "To verify that the checkpoint loaded successfully, a single training camera is rendered."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d28705eb-b223-4262-88c7-ebae8ed7d2d2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# @title Setup Teacher\n",
    "rng = jax.random.PRNGKey(config.jax_rng_seed)\n",
    "# Keep the model, its state and the parallel eval-render function; the last two return\n",
    "# values of setup_model() are not needed in this notebook.\n",
    "teacher_model, teacher_state, teacher_render_eval_pfn, _, _ = camp_zipnerf.internal.train_utils.setup_model(\n",
    "    config, rng, dataset=dataset\n",
    ")\n",
    "\n",
    "# Summarize the parameter tree by leaf shape. A bare `;s` suffix only works while etils'\n",
    "# auto-display hook is active; in a plain kernel `s` is an undefined name (NameError).\n",
    "jax.tree_util.tree_map(jnp.shape, teacher_state.params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cd8c959-01df-4a61-94ea-150599181472",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# @title Reload teacher's state from disk\n",
    "# Replace the freshly initialized `teacher_state` with what is stored under `config.checkpoint_dir`.\n",
    "# NOTE(review): Flax's restore_checkpoint presumably hands back the target unchanged when the\n",
    "# directory holds no checkpoint, so treat an unexpectedly low printed step as a failed load -- confirm.\n",
    "teacher_state = checkpoints.restore_checkpoint(config.checkpoint_dir, teacher_state)\n",
    "print('step:', teacher_state.step)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d47dabe-3e4a-4949-8837-f056246cc70d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Replicate state across devices.\n",
    "teacher_pstate = flax.jax_utils.replicate(teacher_state)\n",
    "# NOTE: the later cells of this notebook read `teacher_pstate.params` directly;\n",
    "# `teacher_pvariables` itself is not referenced again.\n",
    "teacher_pvariables = teacher_pstate.params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7cc6b91-8fa6-486f-822f-19418124ffd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Render a frame.\n",
    "\n",
    "CAM_IDXS = [0]  # Which train images to render.\n",
    "\n",
    "def main():\n",
    "  \"\"\"Renders every camera in `CAM_IDXS` with the teacher and shows it next to the ground truth.\"\"\"\n",
    "  # These arguments are the same for every camera, so bind them once.\n",
    "  render_fn = functools.partial(\n",
    "      teacher_render_eval_pfn,\n",
    "      teacher_pstate.params,\n",
    "      1.0,\n",
    "      None,  # No cameras needed\n",
    "  )\n",
    "\n",
    "  for cam_idx in CAM_IDXS:\n",
    "    camera_rays = smerf.internal.datasets.cam_to_rays(dataset, cam_idx, xnp=jnp)\n",
    "\n",
    "    # Time the render together with the device-to-host copy of the result.\n",
    "    t0 = time.time()\n",
    "    rendering = jax.device_get(\n",
    "        camp_zipnerf.internal.models.render_image(\n",
    "            render_fn,\n",
    "            rays=camera_rays,\n",
    "            rng=None,\n",
    "            config=config,\n",
    "            return_all_levels=True,\n",
    "        )\n",
    "    )\n",
    "    elapsed = time.time() - t0\n",
    "\n",
    "    print(f'Elapsed time: {elapsed:0.2f}')\n",
    "    print(f'Resolution: {rendering[\"rgb\"].shape}')\n",
    "\n",
    "    # Ground truth and teacher output, side by side.\n",
    "    panels = {'gt': dataset.images[cam_idx], 'teacher': rendering['rgb']}\n",
    "    media.show_images(panels, ylabel=f'{cam_idx=}', width=800)\n",
    "\n",
    "\n",
    "main()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7281980a-cbfd-4235-9fd2-a35f9848ef13",
   "metadata": {},
   "source": [
    "## SMERF"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da7b95a6-9678-4d57-a275-92abf47c9f6a",
   "metadata": {},
   "source": [
    "### Setup\n",
    "\n",
    "This section loads the test dataset, finalizes SMERF's config, and initializes the SMERF model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "392ad6e3-221a-4595-9292-99e9ffd562e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Load test dataset\n",
    "# Same loader and data directory as the train split, but for the 'test' split.\n",
    "test_dataset = camp_zipnerf.internal.datasets.load_dataset('test', config.data_dir, config)\n",
    "# NOTE: the later cells of this notebook use `test_dataset` directly; `test_raybatcher` is\n",
    "# not referenced again.\n",
    "test_raybatcher = camp_zipnerf.internal.datasets.RayBatcher(test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90bd644b-dfdc-489b-b4e7-0ee5e8373da4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# @title Initialize grid_config\n",
    "# Rebinds `smerf_config` to the value returned by initialize_grid_config(), which is given\n",
    "# both the train and the test dataset.\n",
    "smerf_config = smerf.internal.grid_utils.initialize_grid_config(\n",
    "    smerf_config, [dataset, test_dataset]\n",
    ")\n",
    "# The hash value is discarded; the call is only here to fail early if the config is unhashable.\n",
    "hash(smerf_config)  # Make sure hashing works\n",
    "\n",
    "# Last expression of the cell, so Jupyter displays these two config entries.\n",
    "{\n",
    "    'grid_config': smerf_config.grid_config,\n",
    "    'exposure_config': smerf_config.exposure_config,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00e976f4-d837-4b65-91bd-a3ea523a5582",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# @title Initialize model\n",
    "# Keep the student model, its initial state and the parallel train step (`train_pstep`, used\n",
    "# by the training loop); the other two return values of setup_model() are not needed here.\n",
    "model, state, _, train_pstep, _ = smerf.internal.train_utils.setup_model(\n",
    "    smerf_config, jax.random.PRNGKey(smerf_config.model_seed), dataset\n",
    ")\n",
    "# Parallel eval renderer for the student. It is built from both models and is later called\n",
    "# with the teacher's parameters as well as the student's.\n",
    "smerf_render_eval_pfn = smerf.internal.distill.create_prender_student(\n",
    "    teacher_model=teacher_model,\n",
    "    student_model=model,\n",
    "    merf_config=smerf_config,\n",
    "    alpha_threshold=smerf.internal.baking.final_alpha_threshold(smerf_config),\n",
    "    return_ray_results=True,\n",
    ")\n",
    "\n",
    "# Summarize the parameter tree by leaf shape. A bare `;s` suffix only works while etils'\n",
    "# auto-display hook is active; in a plain kernel `s` is an undefined name (NameError).\n",
    "jax.tree_util.tree_map(jnp.shape, state.params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca6fee59-db5e-45c1-add2-16cf0090e474",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# @title Replicate state across devices\n",
    "pstate = flax.jax_utils.replicate(state)\n",
    "\n",
    "# Summarize the replicated parameters by leaf shape. A bare `;s` suffix only works while etils'\n",
    "# auto-display hook is active; in a plain kernel `s` is an undefined name (NameError).\n",
    "jax.tree_util.tree_map(jnp.shape, pstate.params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e779c3d3-700b-4e34-9423-122efcf5c13d",
   "metadata": {},
   "source": [
    "### Train\n",
    "\n",
    "This is the main training loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b4313d4-45b9-471c-8a30-c5c41494bedb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Library\n",
    "\n",
    "# Function for generating teacher supervision.\n",
    "prender_teacher = smerf.internal.distill.create_prender_teacher(teacher_model, config)\n",
    "\n",
    "\n",
    "def mse_to_psnr(v):\n",
    "  \"\"\"Returns `-10 * log10(v)` for positive `v`, and 0.0 for any non-positive `v`.\"\"\"\n",
    "  if v > 0:\n",
    "    return -10 * np.log10(v)\n",
    "  return 0.0\n",
    "\n",
    "\n",
    "def render_example(dataset, cam_idx, teacher_pstate, smerf_pstate):\n",
    "  \"\"\"Shows ground truth, teacher render and student render of one camera side by side.\n",
    "\n",
    "  Args:\n",
    "    dataset: passed to `cam_to_rays` and `preprocess_rays`; its `images` supply the ground truth.\n",
    "    cam_idx: index into `dataset.images` of the camera to render.\n",
    "    teacher_pstate: replicated teacher state; its `params` are fed to both renderers.\n",
    "    smerf_pstate: replicated student state; supplies `params` and `step`.\n",
    "  \"\"\"\n",
    "  assert (\n",
    "      0 <= cam_idx < len(dataset.images)\n",
    "  ), f'{cam_idx=} is not in this dataset.'\n",
    "\n",
    "  # Cast the camera's rays, then apply SMERF's test-mode preprocessing to them.\n",
    "  rays = smerf.internal.datasets.preprocess_rays(\n",
    "      rays=smerf.internal.datasets.cam_to_rays(dataset, cam_idx),\n",
    "      mode='test',\n",
    "      merf_config=smerf_config,\n",
    "      dataset=dataset,\n",
    "  )\n",
    "\n",
    "  # Teacher pass.\n",
    "  teacher_fn = functools.partial(\n",
    "      teacher_render_eval_pfn,\n",
    "      teacher_pstate.params,\n",
    "      1.0,\n",
    "      None,  # No cameras needed\n",
    "  )\n",
    "  teacher_rendering = camp_zipnerf.internal.models.render_image(\n",
    "      teacher_fn,\n",
    "      rays=rays,\n",
    "      rng=None,\n",
    "      config=config,\n",
    "      return_all_levels=True,\n",
    "  )\n",
    "\n",
    "  # Student pass; the student renderer is also handed the teacher's parameters.\n",
    "  student_fn = functools.partial(\n",
    "      smerf_render_eval_pfn,\n",
    "      teacher_pstate.params,\n",
    "      smerf_pstate.params,\n",
    "      1.0,\n",
    "  )\n",
    "  smerf_rendering = smerf.internal.models.render_image(\n",
    "      student_fn,\n",
    "      rays=rays,\n",
    "      rng=None,\n",
    "      config=smerf_config,\n",
    "      verbose=False,\n",
    "  )\n",
    "\n",
    "  # Step shown in the label: one replica's raw counter divided by the accumulation factor.\n",
    "  step = int(smerf_pstate.step[0] // smerf_config.gradient_accumulation_steps)\n",
    "\n",
    "  panels = {\n",
    "      'gt': dataset.images[cam_idx],\n",
    "      'teacher': teacher_rendering['rgb'],\n",
    "      'student': smerf_rendering['rgb'],\n",
    "  }\n",
    "  media.show_images(panels, ylabel=f'{cam_idx=} {step=}', width=2000 // 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fa2ccec-ba50-4773-be99-1765616d624a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Train\n",
    "#\n",
    "# This section will train a SMERF model, printing the train losses (passed through\n",
    "# `mse_to_psnr`) every `PRINT_EVERY` steps and rendering a single test camera every\n",
    "# `RENDER_EVERY` steps.\n",
    "#\n",
    "\n",
    "PRINT_EVERY = 10     # Print loss every N steps\n",
    "RENDER_EVERY = 100   # Render a test image every N steps\n",
    "\n",
    "def train(pstate):\n",
    "  \"\"\"Runs the SMERF training loop as a generator.\n",
    "\n",
    "  Args:\n",
    "    pstate: replicated SMERF train state to start (or resume) from.\n",
    "\n",
    "  Yields:\n",
    "    The replicated state after each outer step, so the caller always holds the most recent\n",
    "    state even when the loop is interrupted.\n",
    "  \"\"\"\n",
    "  # Prepare dataset for iterating\n",
    "  p_raybatcher = flax.jax_utils.prefetch_to_device(raybatcher, 3)\n",
    "  prng = jax.random.split(jax.random.PRNGKey(1234567), jax.local_device_count())\n",
    "\n",
    "  # Start from the step stored in the state that was passed in (index 0 reads one replica's\n",
    "  # copy). Reading the global, unreplicated `state` instead would always restart at its\n",
    "  # initial step, because training never updates that object.\n",
    "  step = int(pstate.step[0]) // smerf_config.gradient_accumulation_steps\n",
    "  for i in range(step, smerf_config.max_steps):\n",
    "    # Compute fraction of training complete.\n",
    "    train_frac = np.clip((i - 1) / (smerf_config.max_steps - 1), 0, 1)\n",
    "\n",
    "    # One outer step consists of `gradient_accumulation_steps` calls to train_pstep.\n",
    "    for j in range(smerf_config.gradient_accumulation_steps):\n",
    "      pbatch = next(p_raybatcher)\n",
    "\n",
    "      # Cast rays\n",
    "      pbatch = pbatch.replace(\n",
    "          rays=smerf.internal.datasets.preprocess_rays(\n",
    "              rays=pbatch.rays,\n",
    "              mode='train',\n",
    "              merf_config=smerf_config,\n",
    "              dataset=dataset,\n",
    "              pcameras=pcameras,\n",
    "              prng=prng,\n",
    "          ),\n",
    "      )\n",
    "\n",
    "      # Push ray origins forward along camera rays.\n",
    "      pbatch = smerf.internal.train_utils.pshift_batch_forward(\n",
    "          prng=prng,\n",
    "          pbatch=pbatch,\n",
    "          teacher_pstate=teacher_pstate,\n",
    "          prender_teacher=prender_teacher,\n",
    "          config=smerf_config,\n",
    "      )\n",
    "\n",
    "      # Render teacher.\n",
    "      teacher_prng = prng if smerf_config.distill_teacher_use_rng else None\n",
    "      pteacher_history = prender_teacher(teacher_prng, teacher_pstate, pbatch)\n",
    "\n",
    "      # Update SMERF parameters. train_pstep also returns the PRNG keys for the next call.\n",
    "      pstate, pstats, prng = train_pstep(\n",
    "          prng, pstate, pbatch, pteacher_history, train_frac\n",
    "      )\n",
    "\n",
    "    # Print the stats of the last accumulation step.\n",
    "    if i % PRINT_EVERY == 0:\n",
    "      stats = flax.jax_utils.unreplicate(pstats)\n",
    "      psnr = {k: f\"{mse_to_psnr(v):0.2f}\" for k, v in stats['losses'].items()}\n",
    "      print(f'{i:05d}: {psnr}')\n",
    "\n",
    "    # Render a test frame.\n",
    "    if i % RENDER_EVERY == 0:\n",
    "      render_example(test_dataset, 0, teacher_pstate, pstate)\n",
    "\n",
    "    yield pstate\n",
    "\n",
    "# Run train loop. Rebinding the global `pstate` on every iteration keeps the latest state\n",
    "# available in the notebook if the cell is interrupted.\n",
    "pstate_iter = train(pstate)\n",
    "for pstate in pstate_iter:\n",
    "  pass\n",
    "print('Done!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1768f0d-7373-4073-b63c-2b8644a5f4d9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
